{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FFI 函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import ctypes\n",
    "import numpy as np\n",
    "from tvm import ffi as tvm_ffi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `echo`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fecho = tvm_ffi.get_global_func(\"testing.echo\")\n",
    "assert isinstance(fecho, tvm_ffi.Function)\n",
    "# test each type\n",
    "assert fecho(None) is None\n",
    "\n",
    "# test bool\n",
    "bool_result = fecho(True)\n",
    "assert isinstance(bool_result, bool)\n",
    "assert bool_result is True\n",
    "bool_result = fecho(False)\n",
    "assert isinstance(bool_result, bool)\n",
    "assert bool_result is False\n",
    "\n",
    "# test int/float\n",
    "assert fecho(1) == 1\n",
    "assert fecho(1.2) == 1.2\n",
    "\n",
    "# test str\n",
    "str_result = fecho(\"hello\")\n",
    "assert isinstance(str_result, str)\n",
    "assert str_result == \"hello\"\n",
    "assert isinstance(str_result, tvm_ffi.String)\n",
    "\n",
    "# test bytes\n",
    "bytes_result = fecho(b\"abc\")\n",
    "assert isinstance(bytes_result, bytes)\n",
    "assert bytes_result == b\"abc\"\n",
    "assert isinstance(bytes_result, tvm_ffi.Bytes)\n",
    "\n",
    "# test dtype\n",
    "dtype_result = fecho(tvm_ffi.dtype(\"float32\"))\n",
    "assert isinstance(dtype_result, tvm_ffi.dtype)\n",
    "assert dtype_result == tvm_ffi.dtype(\"float32\")\n",
    "\n",
    "# test device\n",
    "device_result = fecho(tvm_ffi.device(\"cuda:1\"))\n",
    "assert isinstance(device_result, tvm_ffi.Device)\n",
    "assert device_result.device_type == tvm_ffi.Device.kDLCUDA\n",
    "assert device_result.device_id == 1\n",
    "assert str(device_result) == \"cuda:1\"\n",
    "assert device_result.__repr__() == \"device(type='cuda', index=1)\"\n",
    "\n",
    "# test c_void_p\n",
    "c_void_p_result = fecho(ctypes.c_void_p(0x12345678))\n",
    "assert isinstance(c_void_p_result, ctypes.c_void_p)\n",
    "assert c_void_p_result.value == 0x12345678\n",
    "\n",
    "# test function: aka object\n",
    "fadd = tvm_ffi.convert(lambda a, b: a + b)\n",
    "fadd1 = fecho(fadd)\n",
    "assert fadd1(1, 2) == 3\n",
    "assert fadd1.same_as(fadd)\n",
    "\n",
    "def check_ndarray():\n",
    "    np_data = np.arange(10, dtype=\"int32\")\n",
    "    if not hasattr(np_data, \"__dlpack__\"):\n",
    "        return\n",
    "    # test NDArray\n",
    "    x = tvm_ffi.from_dlpack(np_data)\n",
    "    assert isinstance(x, tvm_ffi.NDArray)\n",
    "    nd_result = fecho(x)\n",
    "    assert isinstance(nd_result, tvm_ffi.NDArray)\n",
    "    assert nd_result.shape == (10,)\n",
    "    assert nd_result.dtype == tvm_ffi.dtype(\"int32\")\n",
    "    assert nd_result.device.device_type == tvm_ffi.Device.kDLCPU\n",
    "    assert nd_result.device.device_id == 0\n",
    "\n",
    "check_ndarray()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 返回原始字符"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert tvm_ffi.convert(lambda: \"hello\")() == \"hello\"\n",
    "assert tvm_ffi.convert(lambda: b\"hello\")() == b\"hello\"\n",
    "assert tvm_ffi.convert(lambda: bytearray(b\"hello\"))() == b\"hello\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Python 函数转换"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add(a, b):\n",
    "    return a + b\n",
    "\n",
    "fadd = tvm_ffi.convert(add)\n",
    "assert isinstance(fadd, tvm_ffi.Function)\n",
    "assert fadd(1, 2) == 3\n",
    "\n",
    "def fapply(f, *args):\n",
    "    return f(*args)\n",
    "\n",
    "fapply = tvm_ffi.convert(fapply)\n",
    "assert fapply(add, 1, 3.3) == 4.3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 注册全局函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm_ffi.register_func(\"mytest.echo\")\n",
    "def echo(x):\n",
    "    return x\n",
    "\n",
    "f = tvm_ffi.get_global_func(\"mytest.echo\")\n",
    "assert f.same_as(echo)\n",
    "assert f(1) == 1\n",
    "\n",
    "assert \"mytest.echo\" in tvm_ffi.registry.list_global_func_names()\n",
    "\n",
    "tvm_ffi.registry.remove_global_func(\"mytest.echo\")\n",
    "assert \"mytest.echo\" not in tvm_ffi.registry.list_global_func_names()\n",
    "assert tvm_ffi.get_global_func(\"mytest.echo\", allow_missing=True) is None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 右值引用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "use_count = tvm_ffi.get_global_func(\"testing.object_use_count\")\n",
    "\n",
    "def callback(x, expected_count):\n",
    "    # The use count of TVM FFI objects is decremented as part of\n",
    "    # `ObjectRef.__del__`, which runs when the Python object is\n",
    "    # destructed.  However, Python object destruction is not\n",
    "    # deterministic, and even CPython's reference-counting is\n",
    "    # considered an implementation detail.  Therefore, to ensure\n",
    "    # correct results from this test, `gc.collect()` must be\n",
    "    # explicitly called.\n",
    "    gc.collect()\n",
    "    assert expected_count == use_count(x)\n",
    "    return x._move()\n",
    "\n",
    "f = tvm_ffi.convert(callback)\n",
    "\n",
    "def check0():\n",
    "    x = tvm_ffi.convert([1, 2])\n",
    "    assert use_count(x) == 1\n",
    "    f(x, 2)\n",
    "    y = f(x._move(), 1)\n",
    "    assert x.__ctypes_handle__().value == None\n",
    "\n",
    "def check1():\n",
    "    x = tvm_ffi.convert([1, 2])\n",
    "    assert use_count(x) == 1\n",
    "    y = f(x, 2)\n",
    "    z = f(x._move(), 2)\n",
    "    assert x.__ctypes_handle__().value == None\n",
    "    assert y.__ctypes_handle__().value is not None\n",
    "\n",
    "check0()\n",
    "check1()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
